home *** CD-ROM | disk | FTP | other *** search
/ Tech Arsenal 1 / Tech Arsenal (Arsenal Computer).ISO / tek-04 / nasanets.zip / TEACH.C < prev    next >
C/C++ Source or Header  |  1990-06-07  |  20KB  |  460 lines

  1. /*=============================*/
  2. /*           NETS              */
  3. /*                             */
  4. /* a product of the AI Section */
  5. /* NASA, Johnson Space Center  */
  6. /*                             */
  7. /* principal author:           */
  8. /*       Paul Baffes           */
  9. /*                             */
  10. /* contributing authors:       */
  11. /*      Bryan Dulock           */
  12. /*      Chris Ortiz            */
  13. /*=============================*/
  14.  
  15.  
  16. /*
  17. ----------------------------------------------------------------------
  18.   Code For Teaching Networks (Prefix = T_)
  19. ----------------------------------------------------------------------
  20.   This code is divided into 4 major sections:
  21.  
  22.   (1) include files
  23.   (2) externed functions
  24.   (3) externed global variables
  25.   (4) subroutines
  26.  
  27.   Each section is further explained below.
  28. ----------------------------------------------------------------------
  29. */
  30.  
  31.  
  32. /*
  33. ----------------------------------------------------------------------
  34.   INCLUDE FILES
  35. ----------------------------------------------------------------------
  36. */
  37. #include "common.h"
  38. #include "weights.h"
  39. #include "layer.h"
  40. #include "net.h"
  41.  
  42.  
  43. /*
  44. ----------------------------------------------------------------------
  45.   EXTERNED FUNCTIONS
  46. ----------------------------------------------------------------------
  47.   Below are the functions defined in other files which are used by the
  48.   code here. They are organized by section.
  49. ----------------------------------------------------------------------
  50. */
  51. extern  void         L_update_learn_rates();
  52.  
  53. extern  void         PA_rewind_workfile();
  54. extern  void         PA_done_with_workfile();
  55. extern  Sint         PA_retrieve();
  56.  
  57. extern  Sint         C_float_to_Sint();
  58. extern  float        C_Sint_to_float();
  59.  
  60. extern  void         P_prop_input();
  61. extern  void         P_prop_error();
  62. extern  Sint         P_calc_delta();
  63. extern  void         P_change_biases();
  64.  
  65. extern  void         D_dribble();
  66. extern  float        sys_get_time();
  67. extern  void         IO_print();
  68. extern  void         IO_insert_format();
  69.  
  70.  
  71.  
  72. /*
  73. ----------------------------------------------------------------------
  74.   EXTERNED GLOBALS
  75. ----------------------------------------------------------------------
  76.   Finally, some globals imported from the NETMAIN.C code which drive
  77.   the dribble option of NETS. This is a fairly new option (3-89) which
  78.   allows you to save the errors, weights, or the results of a test 
  79.   input while training a network.
  80.   
  81.   All that I need to know here is which of the flags for dribbling (if
  82.   any) are set. There are three of these. If any are set, then I simply
  83.   call the D_dribble routine to handle the rest.
  84. ----------------------------------------------------------------------
  85. */
  86. extern int    SAVE_MAXERR;   /* all below are from the IO file */
  87. extern int    SAVE_RMSERR;
  88. extern int    SAVE_WEIGHTS;
  89. extern int    SAVE_TESTS;
  90. extern char   IO_str[MAX_LINE_SIZE];
  91. extern char   IO_wkstr[MAX_LINE_SIZE];
  92.  
  93.  
  94. /*
  95. ======================================================================
  96.   ROUTINES IN TEACH.C                                                   
  97. ======================================================================
  98.   The routines in this file are grouped below by function.  Each routine
  99.   is prefixed by the string "T_" indicating that it is defined in the 
  100.   "teach.c" file.  The types returned by the routines are also shown here
  101.   so that cross checking is more easily done between these functions
  102.   and the other files which intern them.
  103.  
  104.  
  105.   Type Returned                 Routine                                 
  106.   -------------                 -------                                 
  107.                                                                      
  108.   CREATION ROUTINES                                                    
  109.     void                        T_teach_net                             
  110.     Sint                        T_process_iopairs                       
  111.     void                        T_setup_inputs                          
  112.     Sint                        T_setup_errors     
  113.     void                        T_change_weights 
  114. ======================================================================
  115. */
  116.  
  117.  
  118. void  T_teach_net(ptr_net, desired_err, stop_cycle, print_mod)
  119. Net   *ptr_net;
  120. Sint  desired_err;
  121. int   stop_cycle;
  122. int   print_mod;
  123. /*
  124. ----------------------------------------------------------------------
  125.  'teach_net' is really the guts of this entire program, since most of 
  126.   the time spent running this code will be spent here and since the   
  127.   success or failure of the code rests here.  Because so much running 
  128.   time will be spent in this routine, it is imperative that it be     
  129.   efficient and much of the design of the data structures for the net 
  130.   went into making this possible.                                     
  131.  The basic idea is to follow the algorithm set down in the Rummelhart 
  132.   et. al. paper on the generalized delta approach to neural nets.     
  133.   Here, the idea is to take an input, propagate it through the net,   
  134.   measure the subsequent output against your desired output, and then 
  135.   propagate the error back through the net.  This process is repeated 
  136.   for each IO pair in turn until such time as ALL the io pairs have   
  137.   been learned within an acceptable error.                            
  138. ----------------------------------------------------------------------
  139.  4-7-89  I added another parameter called "print_mod" which gets passed
  140.   in from netmain.c to be used when printing out the Max error/RMS error
  141.   message to the screen. The print_mod indicates how many training 
  142.   cycles should elapse between printouts to the screen. Note that this
  143.   printout statement used to be in the T_process_iopairs routine and 
  144.   was moved here because this routine keeps the cycle number. Thus, all
  145.   we need do is check the num_cycles against the print_mod to see if 
  146.   the errors should be printed out.
  147. ----------------------------------------------------------------------
  148. */
  149. BEGIN
  150.    Sint   cur_error, T_process_iopairs();
  151.    int    num_cycles;
  152.    float  t1, t2, grad_err;
  153.  
  154.    cur_error  = desired_err * 2;
  155.    num_cycles = 0;
  156.    sprintf(IO_str, "\n*** Learning; please wait ***\n");
  157.    IO_print(0);
  158.  
  159.    t1 = sys_get_time();                             /* record starting time */
  160.    while (cur_error > desired_err) BEGIN
  161.       num_cycles++;
  162.       if (num_cycles > stop_cycle) BEGIN
  163.          sprintf(IO_str, "\n*** Timeout failure ***\n");
  164.          IO_print(0);
  165.          sprintf(IO_str, "*** Net could not learn after %d tries ***\n", 
  166.                 stop_cycle);
  167.          IO_print(0);
  168.          break;
  169.       ENDIF
  170.       cur_error = T_process_iopairs(ptr_net, &grad_err);
  171.  
  172.       /*------------------------*/
  173.       /* print errors to screen */
  174.       /*------------------------*/
  175.       if (print_mod != 0)
  176.          if ((num_cycles % print_mod) == 0) BEGIN
  177.             sprintf(IO_str, "\n Cycle : %d", num_cycles);
  178.             IO_print(0);
  179.             sprintf(IO_wkstr, "   Max error:%%.f    RMS error:%%.f \n");
  180.             IO_insert_format(IO_wkstr);
  181.             sprintf(IO_str, IO_wkstr, C_Sint_to_float(cur_error), grad_err);
  182.             IO_print(0);
  183.          ENDIF
  184.       
  185.       /*-----------------------------------------*/
  186.       /* if dribble parameters set, then dribble */
  187.       /*-----------------------------------------*/
  188.       if (SAVE_MAXERR == TRUE || SAVE_RMSERR == TRUE 
  189.           || SAVE_WEIGHTS == TRUE || SAVE_TESTS == TRUE)
  190.          D_dribble(ptr_net, num_cycles, cur_error, grad_err);
  191.    ENDWHILE
  192.    
  193.    t2 = sys_get_time();                             /* record ending time   */
  194.    PA_done_with_workfile();                       /* make sure file closed*/
  195.    if (num_cycles <= stop_cycle) BEGIN
  196.       sprintf(IO_str, "\nNet learned after %d cycles\n", num_cycles);
  197.       IO_print(0);
  198.       sprintf(IO_str, "Learning time: %7.1f seconds\n", (t2 - t1));
  199.       IO_print(0);
  200.    ENDIF
  201.  
  202. END /* T_teach_net */
  203.  
  204.  
  205. Sint  T_process_iopairs(ptr_net, grad_err)
  206. Net    *ptr_net;
  207. float  *grad_err;
  208. /*
  209. ----------------------------------------------------------------------
  210.  Steps through each of the io pairs once, propagating the input, then 
  211.   propagating back the error.  At each point, a maximum error is      
  212.   determined for the eventual return value from this routine.         
  213.  There may be some confusion over the terminology concerning 'errors' 
  214.   returned from this routine.  Actually, there are two types of error 
  215.   values we are concerned with.  One is used as a measure of when to  
  216.   to stop the processing, and the other is a normalized, overall error
  217.   of how well the net currently knows the entire set of IO pairs. The 
  218.   second number is more for diagnostic purposes than anything else,   
  219.   but it is significant in that the whole theory behind the Rummelhart
  220.   et. al. paper is that this algorithm is guaranteed to minimize the  
  221.   second measure of error.                                            
  222.  To keep things straight, I will call the two errors "stop_err" and   
  223.   "grad_err" since the first measures when the process should stop and
  224.   the second is keeping track of what ought to be a gradient descent. 
  225. ----------------------------------------------------------------------
  226.  I just added (3-16-89) another variable called "avg_err" which is used
  227.   as a pointer to a Sint which will be used to hold the average error
  228.   for A SINGLE IO PAIR. The idea is to keep the average error for one
  229.   input/output pair and then use this average to reset the learning rate
  230.   for that particular IO pair. A guy named Don Woods (MacDonnel Doug.) 
  231.   came up with the idea which he tested on an XOR net. I am simply 
  232.   extending the idea to the general case (ie, multiple outputs) in an
  233.   attempt to help the learning. A problem seems to arise with NETS 
  234.   due to the fact that the deltas calculated can only get so small 
  235.   before no weight changes are made. That is:
  236.   
  237.      delta weight = learn_rate * delta(j) * output(i) 
  238.      
  239.   for a weight connecting node i to output node j. If the learn_rate is
  240.   small (say .1) and the output if i is average (say .5) then the delta
  241.   must be .02 to generate a weight change of .001 (our minimum precision).
  242.   Now:
  243.   
  244.      delta(j) = (t - o) * o(1-o)
  245.      
  246.   for the output deltas. With a desired delta of .02, we either have
  247.   to have large (t-o) values, or large o values. That means NETS will
  248.   stop learning (using a learn_rate of .1) when it starts to get close!
  249.   The use of the avg err should help fix that.
  250.  The avg_err is passed to T_change_weights so that each layer may 
  251.   calculate its own learning rate based on the average error.
  252. ----------------------------------------------------------------------
  253.  4-7-89 Note that I have moved the printout statement for the errors 
  254.   to the T_teach_net routine so that a modulus argument could be used
  255.   for specification of how often the errors should be printed.
  256. ----------------------------------------------------------------------
  257. */
  258. BEGIN
  259.    void    T_setup_inputs(), T_change_weights();
  260.    Sint    stop_err, temp_s_err, avg_err, T_setup_errors();
  261.    int     i;
  262.  
  263.    stop_err = 0;
  264.    *grad_err = 0;
  265.    PA_rewind_workfile();
  266.    for (i = 0; i < ptr_net->num_io_pairs; i++) BEGIN
  267.       T_setup_inputs(ptr_net);
  268.       P_prop_input(ptr_net);
  269.       temp_s_err = T_setup_errors(ptr_net, grad_err, &avg_err);
  270.       if (temp_s_err > stop_err)
  271.          stop_err = temp_s_err;
  272.       
  273.       P_prop_error(ptr_net);
  274.       L_update_learn_rates(ptr_net->hidden_front, avg_err);
  275.       T_change_weights(ptr_net);
  276.       if (ptr_net->use_biases == TRUE)
  277.          P_change_biases(ptr_net->hidden_front);
  278.    ENDFOR
  279.    
  280.    *grad_err = sqrt( (*grad_err / ((float)ptr_net->output_layer->num_nodes
  281.                                   * (float)ptr_net->num_io_pairs)) );
  282.    return(stop_err);
  283.  
  284. END /* T_process_iopairs */
  285.  
  286.  
  287. void  T_setup_inputs(ptr_net)
  288. Net  *ptr_net;
  289. /*
  290. ----------------------------------------------------------------------
  291.  Reads in the correct number of inputs from the temporary file setup  
  292.   by the 'PA_parse_iopairs' routine above, and places these values as  
  293.   the outputs of the INPUT layer of the net (pointed to by 'ptr_net') 
  294.   Note that it can be assumed that this intermediate file is in the   
  295.   proper format since the 'PA_parse_iopairs' routine must be run       
  296.   successfully prior to this routine.                                 
  297. ----------------------------------------------------------------------
  298. */
  299. BEGIN
  300.    int  i;
  301.  
  302.    for (i = 0; i < ptr_net->input_layer->num_nodes; i++)
  303.       ptr_net->input_layer->node_outputs[i] = PA_retrieve();
  304.    
  305. END /* T_setup_inputs */
  306.  
  307.  
  308. Sint  T_setup_errors(ptr_net, ptr_grad_err, ptr_avg_err)
  309. Net    *ptr_net;
  310. float  *ptr_grad_err;
  311. Sint   *ptr_avg_err;
  312. /*
  313. ----------------------------------------------------------------------
  314.  This routine is very similar to the 'T_setup_inputs' routine above,  
  315.   differing in that it works with the outputs rather than the inputs. 
  316.   Here, the process involves looking at what the net generated, called
  317.   the observed outputs, vs. what you wanted the net to generate,      
  318.   called the target outputs.  Of course, the target outputs are just  
  319.   the output part of the IO pair.  Once these two values are obtained 
  320.   our error, called new_error, becomes the difference between the two 
  321.   values.  The motivation behind calculating  this error (and then the
  322.   delta for the node) is explained in the Rummelhart et. al. paper.   
  323.  Besides calculating the errors and deltas for each output node, this 
  324.   routine also has to send back a message indicating whether or not   
  325.   the error was greater than the desired error.  Actually, this guy   
  326.   just sends back whatever the maximum error was, leaving the job of  
  327.   determining the maximum to the 'T_process_iopairs' routine in net.c 
  328.   Note, however, that since we are doing a simple subtraction to find 
  329.   our error we can get both positive and negative errors.  In order   
  330.   to make the job of determining the maximum easier, this routine     
  331.   always sends back the ABSOLUTE VALUE of the calculated error.       
  332.  Referring back to the T_process_iopairs routine, note that there are 
  333.   two different measures of error, the sort mentioned above and a     
  334.   second type to track gradient descent.  Since this routine can only 
  335.   return one type, the second sort of error is communicated via the   
  336.   pointer to the floating point value passed to this routine.         
  337. ----------------------------------------------------------------------
  338. */
  339. BEGIN
  340.    Sint   result, target, observed, new_error;
  341.    float  t_float;
  342.    int    i;
  343.  
  344.    /*------------------------------------*/
  345.    /* clear the result and average error */
  346.    /*------------------------------------*/
  347.    result = 0;
  348.    *ptr_avg_err = 0;
  349.    
  350.    /*---------------------------*/
  351.    /* loop for all output nodes */
  352.    /*---------------------------*/
  353.    for (i = 0; i < ptr_net->output_layer->num_nodes; i++) BEGIN
  354.    
  355.       /*----------------------------------*/
  356.       /* get the desired output; subtract */
  357.       /* what you got to set new_error    */
  358.       /*----------------------------------*/
  359.       target = PA_retrieve();
  360.       observed = ptr_net->output_layer->node_outputs[i];
  361.       new_error = target - observed;
  362.       
  363.       /*--------------------------------*/
  364.       /* increment the avg error by the */
  365.       /* new error (ie, running total)  */
  366.       /*--------------------------------*/
  367. #if  USE_SCALED_INTS
  368.       *ptr_avg_err += abs(new_error);
  369. #else
  370.       *ptr_avg_err += fabs(new_error);
  371. #endif
  372.       
  373.       /*-----------------------------------*/
  374.       /* add the square of the error to    */
  375.       /* the grad_err variable. These will */
  376.       /* be summed, divided, then sqrt by  */
  377.       /* the T_process_iopairs routine     */
  378.       /*-----------------------------------*/
  379.       t_float = C_Sint_to_float(new_error);
  380.       *ptr_grad_err += (t_float * t_float);
  381.       ptr_net->output_layer->node_deltas[i] =
  382.                                 P_calc_delta( (D_Sint)new_error, observed );
  383.                                 
  384.       /*---------------------------------------------*/
  385.       /* save new error as max if larger than result */
  386.       /*---------------------------------------------*/
  387. #if  USE_SCALED_INTS
  388.       if (abs(new_error) > result)
  389.          result = abs(new_error);
  390. #else
  391.       if (fabs(new_error) > result)
  392.          result = fabs(new_error);
  393. #endif
  394.    ENDFOR
  395.    
  396.    /*-------------------------------------*/
  397.    /* don't forget to divide avg error by */
  398.    /* number of outputs which is equal to */
  399.    /* the "i" value after the loop!!      */
  400.    /*-------------------------------------*/
  401.    *ptr_avg_err = *ptr_avg_err / ((Sint) i);
  402.    return(result);
  403.  
  404. END /* T_setup_errors */
  405.  
  406.  
  407. void  T_change_weights(ptr_net)
  408. Net     *ptr_net;
  409. /*
  410. ----------------------------------------------------------------------
  411.  Propagates back through the net, changing the weights to reflect the 
  412.   newest changes in the delta values.                                 
  413.  Changing the weights in the net is similar to the two propagate      
  414.   functions above in that all of the layers must be visited. However, 
  415.   since our weights are connected to both their source and target     
  416.   layers, some care has to be taken to ensure that all the weights are
  417.   visited ONLY ONCE.  To do this, you can either propagate forward or 
  418.   backward through the net, and I just happened to choose backward as 
  419.   the way to go.  Thus, I start at the second to last layer (layer 1, 
  420.   the output, is the last layer) and work my way back to layer 0.     
  421.   Each layer is considered a SOURCE as it is visited, thus only the   
  422.   out_weights to its TARGETS are updated for the layer.  This keeps a 
  423.   set of weights from being updated twice.                            
  424.  Once a set of weights is found, then the 'w_update' routine is     
  425.   called to do the work of changing the weight values.
  426.  I have added another parameter to this routine called "avg_err" which
  427.   holds the average error value for ONE IO PAIR. This value is used by
  428.   this routine to determine the learning rate for this particular IO
  429.   pair. The idea (from Don Woods) is to use the cosecant (1/sin) to
  430.   boost the learning rate for both high and low errors. The theory is
  431.   that high errors need lots of change (because they're high) and low
  432.   errors need lots of change (because they're close). I don't know how
  433.   sound that really is; however, NETS is having problems learning when
  434.   using low learn_rates (see T_setup_errors). Hopefully, this change 
  435.   will enable NETS to overcome its precision problems. The change is
  436.   
  437.     new_learn = learn_rate * csc(avg_err * pi)
  438. ----------------------------------------------------------------------
  439. */
  440. BEGIN
  441.    Layer_lst    *cur_layer;
  442.    Weights_lst  *cur_weight;
  443.    Layer        *target, *source;
  444.    Weights      *the_weights;
  445.      
  446.    cur_layer = ptr_net->hidden_back;
  447.    while (cur_layer != NULL) BEGIN
  448.       source     = cur_layer->value;
  449.       cur_weight = cur_layer->value->out_weights;
  450.       while (cur_weight != NULL) BEGIN
  451.          target      = cur_weight->value->target_layer;
  452.          the_weights = cur_weight->value;
  453.          (*the_weights->w_update) (source, target, the_weights);
  454.          cur_weight = cur_weight->next;
  455.       ENDWHILE
  456.       cur_layer = cur_layer->prev;
  457.    ENDWHILE
  458.  
  459. END /* T_change_weights */
  460.